{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from glob import glob\n",
    "import hashlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from numba import njit\n",
    "\n",
    "from dataset.common import inverse_dihedral_transform\n",
    "\n",
    "\n",
    "DATASET_PATH = \"data/arc-aug-1000\"  # ARC-1\n",
    "# DATASET_PATH = \"data/arc-2-aug-1000\"  # ARC-2\n",
    "\n",
    "CHECKPOINT_PATH = \"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\"\n",
    "\n",
    "\n",
    "PAD_PUZZLE_IDENTIFIER = 0\n",
    "\n",
    "# Visualization\n",
    "ARC_COLOR_MAP = mcolors.ListedColormap([\n",
    "    \"#000000\",  # symbol_0: black\n",
    "    \"#0074D9\",  # symbol_1: blue\n",
    "    \"#FF4136\",  # symbol_2: red\n",
    "    \"#2ECC40\",  # symbol_3: green\n",
    "    \"#FFDC00\",  # symbol_4: yellow\n",
    "    \"#AAAAAA\",  # symbol_5: grey\n",
    "    \"#F012BE\",  # symbol_6: fuschia\n",
    "    \"#FF851B\",  # symbol_7: orange\n",
    "    \"#7FDBFF\",  # symbol_8: teal\n",
    "    \"#870C25\"   # symbol_9: brown\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\n",
    "    # Load puzzle identifiers\n",
    "    with open(os.path.join(dataset_path, \"identifiers.json\"), \"r\") as f:\n",
    "        identifier_map = json.load(f)\n",
    "        \n",
    "    # Load preds\n",
    "    all_preds = {}\n",
    "    for filename in glob(f\"{checkpoint_path}_all_preds.*\"):\n",
    "        preds = torch.load(filename)\n",
    "        for k, v in preds.items():\n",
    "            all_preds.setdefault(k, [])\n",
    "            all_preds[k].append(v)\n",
    "            \n",
    "        del preds\n",
    "\n",
    "    all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n",
    "    \n",
    "    # Remove paddings\n",
    "    mask = all_preds[\"puzzle_identifiers\"] != PAD_PUZZLE_IDENTIFIER\n",
    "    all_preds = {k: v[mask] for k, v in all_preds.items()}\n",
    "\n",
    "    return identifier_map, all_preds\n",
    "\n",
    "\n",
    "def inverse_aug(name: str, grid: np.ndarray):\n",
    "    if \"_\" not in name:\n",
    "        return grid\n",
    "\n",
    "    trans_id, perm = name.split(\"_\")[-2:]\n",
    "    trans_id = int(trans_id[1:])  # Remove \"t\" letter\n",
    "    inv_perm = np.argsort(list(perm))\n",
    "    \n",
    "    return inv_perm[inverse_dihedral_transform(grid, trans_id)]\n",
    "\n",
    "\n",
    "def grid_hash(grid: np.ndarray):\n",
    "    return hash((grid.tobytes(), grid.shape))\n",
    "\n",
    "\n",
    "@njit\n",
    "def crop(grid: np.ndarray):\n",
    "    # Find maximum-sized rectangle without any EOS token inside.\n",
    "    grid = grid.reshape(30, 30)\n",
    "\n",
    "    max_area = 0\n",
    "    max_size = (0, 0)\n",
    "    nr, nc = grid.shape\n",
    "    \n",
    "    num_c = nc\n",
    "    for num_r in range(1, nr + 1):\n",
    "        # Scan for maximum c\n",
    "        for c in range(1, num_c + 1):\n",
    "            x = grid[num_r - 1, c - 1]\n",
    "            if (x < 2) | (x > 11):\n",
    "                num_c = c - 1\n",
    "                break\n",
    "        \n",
    "        area = num_r * num_c\n",
    "        if area > max_area:\n",
    "            max_area = area\n",
    "            max_size = (num_r, num_c)\n",
    "\n",
    "    return grid[:max_size[0], :max_size[1]] - 2\n",
    "\n",
    "\n",
    "def test(visualize, Ks=[1, 2, 10, 100, 1000]):\n",
    "    identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\n",
    "    \n",
    "    global_hmap = {}\n",
    "    \n",
    "    # Get puzzles and corresponding answers\n",
    "    puzzle_labels = {}\n",
    "    for identifier, input, label in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], all_preds[\"labels\"]):\n",
    "        name = identifier_map[identifier]\n",
    "        if \"_\" not in name:   # Not-augmented\n",
    "            puzzle_labels.setdefault(name, {})\n",
    "            \n",
    "            input = crop(input.numpy())\n",
    "            label = crop(label.numpy())\n",
    "\n",
    "            input_hash = grid_hash(input)\n",
    "            label_hash = grid_hash(label)\n",
    "\n",
    "            global_hmap[input_hash] = input\n",
    "            global_hmap[label_hash] = label\n",
    "\n",
    "            assert input_hash not in puzzle_labels[name]\n",
    "            puzzle_labels[name][input_hash] = label_hash\n",
    "            \n",
    "    print (\"Number of puzzles\", len(puzzle_labels))\n",
    "    \n",
    "    # Argmax prediction\n",
    "    preds = all_preds[\"logits\"].argmax(-1)\n",
    "\n",
    "    # Collate\n",
    "    pred_answers = {}\n",
    "    for identifier, input, pred, q in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], preds, all_preds[\"q_halt_logits\"].sigmoid()):\n",
    "        name = identifier_map[identifier]\n",
    "        orig_name = name.split(\"_\")[0]\n",
    "        \n",
    "        input = input.numpy()\n",
    "        input_hash = grid_hash(inverse_aug(name, crop(input)))\n",
    "        assert input_hash in puzzle_labels[orig_name]\n",
    "        \n",
    "        pred = inverse_aug(name, crop(pred.numpy()))\n",
    "        pred_hash = grid_hash(pred)\n",
    "        global_hmap[pred_hash] = pred\n",
    "        \n",
    "        pred_answers.setdefault(orig_name, {})\n",
    "        pred_answers[orig_name].setdefault(input_hash, [])\n",
    "        pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\n",
    "\n",
    "    # test-1\n",
    "    if visualize:\n",
    "        num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\n",
    "        fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\n",
    "        \n",
    "        fig_id = 0\n",
    "    \n",
    "    correct = [0 for _ in range(len(Ks))]\n",
    "    for name, tests in puzzle_labels.items():\n",
    "        num_test_correct = [0 for _ in range(len(Ks))]\n",
    "        for input_hash, label_hash in tests.items():\n",
    "            p = pred_answers[name][input_hash]\n",
    "            p_map = {}\n",
    "            \n",
    "            for h, q in p:\n",
    "                p_map.setdefault(h, [0, 0])\n",
    "                p_map[h][0] += 1\n",
    "                p_map[h][1] += q\n",
    "                \n",
    "            for h, stats in p_map.items():\n",
    "                stats[1] /= stats[0]\n",
    "                \n",
    "            p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\n",
    "\n",
    "            # 2-vote\n",
    "            for i, k in enumerate(Ks):\n",
    "                ok = False\n",
    "                for h, stats in p_map[:k]:\n",
    "                    ok |= h == label_hash\n",
    "                    \n",
    "                num_test_correct[i] += ok\n",
    "\n",
    "            if visualize:\n",
    "                # Show input and ground truth\n",
    "                axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\n",
    "                axes[fig_id, 0].set_title(f\"{name}\\nInput\")\n",
    "                axes[fig_id, 0].axis('off')\n",
    "                \n",
    "                axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\n",
    "                axes[fig_id, 1].set_title(f\"{name}\\nAnswer\")\n",
    "                axes[fig_id, 1].axis('off')\n",
    "                \n",
    "                trial_id = 2\n",
    "                for h, stats in p_map[:2]:\n",
    "                    ans = global_hmap[h]\n",
    "                    \n",
    "                    axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\n",
    "                    axes[fig_id, trial_id].set_title(f\"{name}\\nTrial {trial_id}\")\n",
    "                    axes[fig_id, trial_id].axis('off')\n",
    "                    \n",
    "                    trial_id += 1\n",
    "                \n",
    "                fig_id += 1\n",
    "            \n",
    "        # Total correctness\n",
    "        for i in range(len(Ks)):\n",
    "            correct[i] += num_test_correct[i] == len(tests)\n",
    "\n",
    "    for i, k in enumerate(Ks):\n",
    "        print (f\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\")\n",
    "\n",
    "\n",
    "test(visualize=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
